{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prerequisites\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.autograd import Variable\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "# Device configuration\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 100\n",
    "\n",
    "# MNIST Dataset\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Lambda(lambda x: x.repeat(3,1,1)),\n",
    "    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])\n",
    "\n",
    "train_dataset = datasets.MNIST(root='./', train=True, transform=transform, download=True)\n",
    "test_dataset = datasets.MNIST(root='./', train=False, transform=transform, download=False)\n",
    "\n",
    "# Data Loader (Input Pipeline)\n",
    "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    def __init__(self, g_input_dim, g_output_dim):\n",
    "        super(Generator, self).__init__()       \n",
    "        self.fc1 = nn.Linear(g_input_dim, 256)\n",
    "        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)\n",
    "        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)\n",
    "        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)\n",
    "    \n",
    "    # forward method\n",
    "    def forward(self, x): \n",
    "        x = F.leaky_relu(self.fc1(x), 0.2)\n",
    "        x = F.leaky_relu(self.fc2(x), 0.2)\n",
    "        x = F.leaky_relu(self.fc3(x), 0.2)\n",
    "        return torch.tanh(self.fc4(x))\n",
    "    \n",
    "class Discriminator(nn.Module):\n",
    "    def __init__(self, d_input_dim):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.fc1 = nn.Linear(d_input_dim, 1024)\n",
    "        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)\n",
    "        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)\n",
    "        self.fc4 = nn.Linear(self.fc3.out_features, 1)\n",
    "    \n",
    "    # forward method\n",
    "    def forward(self, x):\n",
    "        x = F.leaky_relu(self.fc1(x), 0.2)\n",
    "        x = F.dropout(x, 0.3)\n",
    "        x = F.leaky_relu(self.fc2(x), 0.2)\n",
    "        x = F.dropout(x, 0.3)\n",
    "        x = F.leaky_relu(self.fc3(x), 0.2)\n",
    "        x = F.dropout(x, 0.3)\n",
    "        return torch.sigmoid(self.fc4(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/syno/anaconda3/lib/python3.7/site-packages/torchvision/datasets/mnist.py:53: UserWarning: train_data has been renamed data\n",
      "  warnings.warn(\"train_data has been renamed data\")\n"
     ]
    }
   ],
   "source": [
    "# build network\n",
    "z_dim = 100\n",
    "mnist_dim = train_dataset.train_data.size(1) * train_dataset.train_data.size(2)\n",
    "\n",
    "G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)\n",
    "D = Discriminator(mnist_dim).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Generator(\n",
       "  (fc1): Linear(in_features=100, out_features=256, bias=True)\n",
       "  (fc2): Linear(in_features=256, out_features=512, bias=True)\n",
       "  (fc3): Linear(in_features=512, out_features=1024, bias=True)\n",
       "  (fc4): Linear(in_features=1024, out_features=784, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "G"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Discriminator(\n",
       "  (fc1): Linear(in_features=784, out_features=1024, bias=True)\n",
       "  (fc2): Linear(in_features=1024, out_features=512, bias=True)\n",
       "  (fc3): Linear(in_features=512, out_features=256, bias=True)\n",
       "  (fc4): Linear(in_features=256, out_features=1, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# loss\n",
    "criterion = nn.BCELoss() \n",
    "\n",
    "# optimizer\n",
    "lr = 0.0002 \n",
    "G_optimizer = optim.Adam(G.parameters(), lr = lr)\n",
    "D_optimizer = optim.Adam(D.parameters(), lr = lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def D_train(x):\n",
    "    #=======================Train the discriminator=======================#\n",
    "    D.zero_grad()\n",
    "\n",
    "    # train discriminator on real\n",
    "    x_real, y_real = x.view(-1, mnist_dim), torch.ones(bs, 1)\n",
    "    x_real, y_real = Variable(x_real.to(device)), Variable(y_real.to(device))\n",
    "\n",
    "    D_output = D(x_real)\n",
    "    D_real_loss = criterion(D_output, y_real)\n",
    "    D_real_score = D_output\n",
    "\n",
    "    # train discriminator on facke\n",
    "    z = Variable(torch.randn(bs, z_dim).to(device))\n",
    "    x_fake, y_fake = G(z), Variable(torch.zeros(bs, 1).to(device))\n",
    "\n",
    "    D_output = D(x_fake)\n",
    "    D_fake_loss = criterion(D_output, y_fake)\n",
    "    D_fake_score = D_output\n",
    "\n",
    "    # gradient backprop & optimize ONLY D's parameters\n",
    "    D_loss = D_real_loss + D_fake_loss\n",
    "    D_loss.backward()\n",
    "    D_optimizer.step()\n",
    "        \n",
    "    return  D_loss.data.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def G_train(x):\n",
    "    #=======================Train the generator=======================#\n",
    "    G.zero_grad()\n",
    "\n",
    "    z = Variable(torch.randn(bs, z_dim).to(device))\n",
    "    y = Variable(torch.ones(bs, 1).to(device))\n",
    "\n",
    "    G_output = G(z)\n",
    "    D_output = D(G_output)\n",
    "    G_loss = criterion(D_output, y)\n",
    "\n",
    "    # gradient backprop & optimize ONLY G's parameters\n",
    "    G_loss.backward()\n",
    "    G_optimizer.step()\n",
    "        \n",
    "    return G_loss.data.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/syno/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py:498: UserWarning: Using a target size (torch.Size([100, 1])) that is different to the input size (torch.Size([300, 1])) is deprecated. Please ensure they have the same size.\n",
      "  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Target and input must have the same number of elements. target nelement (100) != input nelement (300)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-10-4bbd97c7e557>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m     \u001b[0mD_losses\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mG_losses\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m         \u001b[0mD_losses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mD_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m         \u001b[0mG_losses\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mG_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-8-7afdcf929475>\u001b[0m in \u001b[0;36mD_train\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0mD_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_real\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m     \u001b[0mD_real_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mD_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_real\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     11\u001b[0m     \u001b[0mD_real_score\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mD_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    539\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    540\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    542\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    543\u001b[0m             \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, target)\u001b[0m\n\u001b[1;32m    496\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    497\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 498\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbinary_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    499\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    500\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbinary_cross_entropy\u001b[0;34m(input, target, weight, size_average, reduce, reduction)\u001b[0m\n\u001b[1;32m   2056\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2057\u001b[0m         raise ValueError(\"Target and input must have the same number of elements. target nelement ({}) \"\n\u001b[0;32m-> 2058\u001b[0;31m                          \"!= input nelement ({})\".format(target.numel(), input.numel()))\n\u001b[0m\u001b[1;32m   2059\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2060\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mweight\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Target and input must have the same number of elements. target nelement (100) != input nelement (300)"
     ]
    }
   ],
   "source": [
    "n_epoch = 200\n",
    "for epoch in range(1, n_epoch+1):           \n",
    "    D_losses, G_losses = [], []\n",
    "    for batch_idx, (x, _) in enumerate(train_loader):\n",
    "        D_losses.append(D_train(x))\n",
    "        G_losses.append(G_train(x))\n",
    "\n",
    "    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (\n",
    "            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    test_z = Variable(torch.randn(bs, z_dim).to(device))\n",
    "    generated = G(test_z)\n",
    "\n",
    "    save_image(generated.view(generated.size(0), 1, 28, 28), './samples/sample_' + '.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
